import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import math
import sys

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Linear-probability estimate of the Black/non-Black gap in ``outcome``.

        Fits OLS of ``outcome`` on ``pbvarb`` (with an intercept) using HC1
        heteroskedasticity-robust standard errors, and prints the number of
        observations used in the fit. Column names are spliced into a formula
        string, so they must be valid formula identifiers.

        Parameters
        ----------
        dataset : pandas.DataFrame
            Must contain the ``pbvarb`` and ``outcome`` columns.
        pbvarb : str
            Column holding the predicted probability of being Black.
        outcome : str
            Column holding the outcome that is regressed on ``pbvarb``.

        Returns
        -------
        tuple
            (coef, se, black_aud_rate, non_black_aud_rate): the slope on
            ``pbvarb``, its HC1 standard error, the fitted value at
            ``pbvarb`` = 1 (intercept + slope) and the fitted value at
            ``pbvarb`` = 0 (intercept).
        """
        model = sm.OLS.from_formula(outcome + ' ~ ' + pbvarb, data = dataset).fit(cov_type = 'HC1')
        coef = model.params[pbvarb]
        se = model.bse[pbvarb]
        # The intercept is the first parameter. Read it positionally with
        # .iloc: plain integer keys (params[0]) on this label-indexed Series
        # are deprecated in pandas 2.x and will be treated as labels (KeyError)
        # in a future version.
        intercept = model.params.iloc[0]
        black_aud_rate = intercept + coef
        non_black_aud_rate = intercept
        print("number of obs :" + str(model.nobs))
        return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
        """Probability-weighted (Chen) estimate of the Black/non-Black gap.

        Each group's rate is a weighted mean of ``outcome``. The Black rate
        weights each row by ``pbvarb``; the non-Black rate weights each row by
        ``1 - pbvarb``.

        Returns
        -------
        tuple
            (est, black_aud_rate, non_black_aud_rate), where
            est = black_aud_rate - non_black_aud_rate.
        """
        p_black = dataset[pbvarb]
        p_other = 1 - p_black
        y = dataset[outcome]

        black_aud_rate = (p_black * y).sum() / p_black.sum()
        non_black_aud_rate = (p_other * y).sum() / p_other.sum()
        return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
        """Return the regression SE together with the implied Chen-estimator SE.

        The Chen SE is ``seReg`` scaled by ``getSEMultiplier(dataset, pbvarb)``.
        ``outcome`` is not used by this function. With the default
        ``seReg=0.0`` no real standard error is produced. Pass the regression
        SE, e.g. ``regEstimate(dataset, pbvarb, outcome)[1]``.

        Returns
        -------
        tuple
            (seReg, seChen)
        """
        multiplier = getSEMultiplier(dataset, pbvarb)
        return seReg, seReg * multiplier

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
        """Ratio that converts the regression SE into the Chen-estimator SE.

        Computes sqrt(Var(p) / (p_bar * (1 - p_bar))), where p is the
        ``pbvarb`` column, Var is the pandas sample variance (ddof=1) and
        p_bar is its mean.
        """
        probs = dataset[pbvarb]
        p_bar = probs.mean()
        bernoulli_var = p_bar * (1 - p_bar)
        return np.sqrt(probs.var() / bernoulli_var)


### Operational audits and predicted race probabilities
keep_cols = [
        'taxpayer_id_new',
        'predicted_prob_black',
        'aud_no_research_audits',
        'dep_database_aud',
        'base_wgt',
        'isEIC',
        'audited',
]
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv", usecols=keep_cols)
# Keep only taxpayers that have a BISG probability.
dataBISG = dataBISG.dropna(subset=['predicted_prob_black'])

# Rename the extract's columns to *_new names; taxpayer_id_new is the merge key.
extract = pd.read_stata('/REDACTED/HertzGraff/cam_eic_extract.dta').rename(columns={
        "taxpayer_id": "taxpayer_id_new",
        "taxpayer_id_typ": "taxpayer_id_typ_new",
        "cycle_posted": "cycle_posted_new",
        "eic": "eic_new",
        "eitc_amt_computer": "eitc_amt_computer_new",
})

# The left merge keeps every BISG row. Rows with no match in the extract get a
# missing eic_new, which fails the > 0 test and is therefore coded 0.
dataBISG = dataBISG.merge(extract, on='taxpayer_id_new', how='left')
dataBISG['isEIC_new'] = np.where(dataBISG['eic_new'].gt(0), 1, 0)


### EITC
# Split the merged sample on the 0/1 isEIC_new flag.
eitc_pop = dataBISG[dataBISG['isEIC_new'].eq(1)]       # isEIC_new == 1
non_eitc_pop = dataBISG[dataBISG['isEIC_new'].eq(0)]   # isEIC_new == 0


## Decomposition (appendix D.3)
#
# Notation (C = isEIC_new == 1, NC = isEIC_new == 0, B = Black, NB = non-Black):
#   D_C, D_NC        Black minus non-Black gap in aud_no_research_audits within C / NC
#   C_B, C_NB        estimated isEIC_new rate of Black / non-Black taxpayers
#   Y_C_B, Y_C_NB    estimated aud_no_research_audits rate of Black / non-Black in C
#   Y_NC_B, Y_NC_NB  estimated aud_no_research_audits rate of Black / non-Black in NC
#   Y_C, Y_NC        raw mean of aud_no_research_audits in C / NC
#   D, D_L           full-sample gap from the Chen / linear estimator


def _decomposition_inputs(estimator, black_idx, non_black_idx):
        """Estimate the decomposition inputs with one estimator.

        Parameters
        ----------
        estimator : callable
            ``chenEstimate`` or ``regEstimate``. Element 0 of its return tuple
            is the Black minus non-Black gap for both.
        black_idx, non_black_idx : int
            Positions of the Black and non-Black rates in the estimator's
            return tuple (1 and 2 for ``chenEstimate``, 2 and 3 for
            ``regEstimate``).

        Returns
        -------
        tuple
            (D_C, D_NC, C_B, C_NB, Y_C_NB, Y_NC_NB, Y_C_B, Y_NC_B)
        """
        pb = 'predicted_prob_black'
        audits = 'aud_no_research_audits'
        # One estimator call per sample; every input is an element of these.
        eitc = estimator(eitc_pop, pbvarb = pb, outcome = audits)
        non_eitc = estimator(non_eitc_pop, pbvarb = pb, outcome = audits)
        claims = estimator(dataBISG, pbvarb = pb, outcome = 'isEIC_new')
        return (eitc[0], non_eitc[0],
                claims[black_idx], claims[non_black_idx],
                eitc[non_black_idx], non_eitc[non_black_idx],
                eitc[black_idx], non_eitc[black_idx])


def _decompositions(D_C, D_NC, C_B, C_NB, Y_C_NB, Y_NC_NB, Y_C_B, Y_NC_B, Y_C, Y_NC):
        """Return {label: (term1, term2, term3)} for the three decompositions of the gap."""
        return {
                "reference group: Black": (
                        D_C * C_B,
                        D_NC * (1 - C_B),
                        (C_B - C_NB) * (Y_C_NB - Y_NC_NB)),
                "reference group: Non-Black": (
                        D_C * C_NB,
                        D_NC * (1 - C_NB),
                        -(C_NB - C_B) * (Y_C_B - Y_NC_B)),
                "Tom's decomp": (
                        (Y_C_B - Y_C)*C_B + (Y_C - Y_C_NB)*C_NB,
                        (Y_NC_B - Y_NC)*(1 - C_B) + (Y_NC - Y_NC_NB)*(1 - C_NB),
                        (C_B - C_NB)*Y_C + (C_NB - C_B)*Y_NC),
        }


Y_C = eitc_pop.aud_no_research_audits.mean()
Y_NC = non_eitc_pop.aud_no_research_audits.mean()
D = chenEstimate(dataBISG, pbvarb = 'predicted_prob_black', outcome = 'aud_no_research_audits')[0]
D_L = regEstimate(dataBISG, pbvarb = 'predicted_prob_black', outcome = 'aud_no_research_audits')[0]

# Both passes bind the same names (D_C, C_B, ..., term1..term3). Each
# estimator's decompositions are therefore reported before the next pass
# rebinds them. After the loop the names hold the linear-estimator values
# (term1..term3 from Tom's decomposition).
for estimator_name, estimator, black_idx, non_black_idx, gap in (
                ("Chen", chenEstimate, 1, 2, D),
                ("linear", regEstimate, 2, 3, D_L)):
        (D_C, D_NC, C_B, C_NB,
         Y_C_NB, Y_NC_NB, Y_C_B, Y_NC_B) = _decomposition_inputs(
                estimator, black_idx, non_black_idx)

        print("=== " + estimator_name + " estimator ===")
        print("isEIC_new rate (all, Black, non-Black):", dataBISG.isEIC_new.mean(), C_B, C_NB)
        print("aud_no_research_audits, C (all, Black, non-Black):", Y_C, Y_C_B, Y_C_NB)
        print("aud_no_research_audits, NC (all, Black, non-Black):", Y_NC, Y_NC_B, Y_NC_NB)

        decomps = _decompositions(D_C, D_NC, C_B, C_NB, Y_C_NB, Y_NC_NB,
                                  Y_C_B, Y_NC_B, Y_C, Y_NC)
        for label, (term1, term2, term3) in decomps.items():
                total = term1 + term2 + term3
                # The Chen terms add up to D, so their shares are of D. The
                # linear terms do not add up to D_L, so their shares are of
                # their own sum instead.
                denom = gap if estimator is chenEstimate else total
                print(label + ":", "gap =", gap, "sum of terms =", total)
                print("    shares:", term1 / denom, term2 / denom, term3 / denom)